{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "38951745"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from keras.layers import Input, Conv2D, Lambda, merge, Dense, Flatten,MaxPooling2D\n",
    "from keras.models import Model, Sequential\n",
    "from keras.regularizers import l2\n",
    "from keras import backend as K\n",
    "from keras.optimizers import SGD,Adam\n",
    "from keras.losses import binary_crossentropy\n",
    "import numpy.random as rng\n",
    "import numpy as np\n",
    "import os\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.utils import shuffle\n",
    "%matplotlib inline\n",
    "def W_init(shape,name=None,dtype=None):\n",
    "    \"\"\"Initialize weights as in paper: normal samples with mean 0 and standard deviation 1e-2.\n",
    "\n",
    "    shape -- shape of the weight array to sample.\n",
    "    name  -- optional name forwarded to the backend variable.\n",
    "    dtype -- optional dtype forwarded to the backend variable; accepted so the function\n",
    "             also works when Keras calls the initializer as initializer(shape, dtype=dtype).\n",
    "\n",
    "    Returns a Keras backend variable holding the sampled values.\n",
    "    \"\"\"\n",
    "    values = rng.normal(loc=0,scale=1e-2,size=shape)\n",
    "    return K.variable(values,dtype=dtype,name=name)\n",
    "#Layer biases are initialized by passing b_init as a layer's bias_initializer argument.\n",
    "def b_init(shape,name=None,dtype=None):\n",
    "    \"\"\"Initialize bias as in paper: normal samples with mean 0.5 and standard deviation 1e-2.\n",
    "\n",
    "    shape -- shape of the bias array to sample.\n",
    "    name  -- optional name forwarded to the backend variable.\n",
    "    dtype -- optional dtype forwarded to the backend variable; accepted so the function\n",
    "             also works when Keras calls the initializer as initializer(shape, dtype=dtype).\n",
    "\n",
    "    Returns a Keras backend variable holding the sampled values.\n",
    "    \"\"\"\n",
    "    values=rng.normal(loc=0.5,scale=1e-2,size=shape)\n",
    "    return K.variable(values,dtype=dtype,name=name)\n",
    "\n",
    "#each input is a single-channel 105x105 image\n",
    "input_shape = (105, 105, 1)\n",
    "left_input = Input(input_shape)\n",
    "right_input = Input(input_shape)\n",
    "#build convnet to use in each siamese 'leg'\n",
    "#every conv layer, including the first, draws its kernel from W_init and its bias from b_init\n",
    "convnet = Sequential()\n",
    "convnet.add(Conv2D(64,(10,10),activation='relu',input_shape=input_shape,\n",
    "                   kernel_initializer=W_init,kernel_regularizer=l2(2e-4),bias_initializer=b_init))\n",
    "convnet.add(MaxPooling2D())\n",
    "convnet.add(Conv2D(128,(7,7),activation='relu',\n",
    "                   kernel_regularizer=l2(2e-4),kernel_initializer=W_init,bias_initializer=b_init))\n",
    "convnet.add(MaxPooling2D())\n",
    "convnet.add(Conv2D(128,(4,4),activation='relu',kernel_initializer=W_init,kernel_regularizer=l2(2e-4),bias_initializer=b_init))\n",
    "convnet.add(MaxPooling2D())\n",
    "convnet.add(Conv2D(256,(4,4),activation='relu',kernel_initializer=W_init,kernel_regularizer=l2(2e-4),bias_initializer=b_init))\n",
    "convnet.add(Flatten())\n",
    "#4096-unit sigmoid embedding that the two legs are compared on\n",
    "convnet.add(Dense(4096,activation=\"sigmoid\",kernel_regularizer=l2(1e-3),kernel_initializer=W_init,bias_initializer=b_init))\n",
    "\n",
    "#call the convnet Sequential model on each of the input tensors so params will be shared\n",
    "encoded_l = convnet(left_input)\n",
    "encoded_r = convnet(right_input)\n",
    "#layer to merge two encoded inputs with the l1 distance between them\n",
    "L1_layer = Lambda(lambda tensors:K.abs(tensors[0] - tensors[1]))\n",
    "#call this layer on list of two input tensors.\n",
    "L1_distance = L1_layer([encoded_l, encoded_r])\n",
    "#single sigmoid unit on the element-wise distance: the same/different score for the pair\n",
    "prediction = Dense(1,activation='sigmoid',bias_initializer=b_init)(L1_distance)\n",
    "siamese_net = Model(inputs=[left_input,right_input],outputs=prediction)\n",
    "\n",
    "optimizer = Adam(0.00006)\n",
    "#//TODO: get layerwise learning rates and momentum annealing scheme described in paper working\n",
    "siamese_net.compile(loss=\"binary_crossentropy\",optimizer=optimizer)\n",
    "\n",
    "#last expression of the cell: displays the total number of parameters\n",
    "siamese_net.count_params()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data \n",
    "The data is pickled as an N_classes x n_examples x width x height array, and there is an accompanying dictionary that specifies which class indices belong to which language."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training alphabets\n",
      "dict_keys(['Hebrew', 'Ojibwe_(Canadian_Aboriginal_Syllabics)', 'Japanese_(katakana)', 'Arcadian', 'Asomtavruli_(Georgian)', 'Syriac_(Estrangelo)', 'Grantha', 'Tagalog', 'Gujarati', 'Tifinagh', 'Latin', 'Bengali', 'Inuktitut_(Canadian_Aboriginal_Syllabics)', 'Alphabet_of_the_Magi', 'Early_Aramaic', 'Braille', 'Korean', 'Greek', 'Futurama', 'Blackfoot_(Canadian_Aboriginal_Syllabics)', 'Anglo-Saxon_Futhorc', 'Mkhedruli_(Georgian)', 'Sanskrit', 'Malay_(Jawi_-_Arabic)', 'Cyrillic', 'Balinese', 'Burmese_(Myanmar)', 'Armenian', 'Japanese_(hiragana)', 'N_Ko'])\n",
      "validation alphabets:\n",
      "dict_keys(['Sylheti', 'Ge_ez', 'Atlantean', 'Gurmukhi', 'Manipuri', 'Syriac_(Serto)', 'Avesta', 'Mongolian', 'Kannada', 'Oriya', 'Angelic', 'Old_Church_Slavonic_(Cyrillic)', 'Malayalam', 'Keble', 'Glagolitic', 'Tengwar', 'Atemayar_Qelisayer', 'ULOG', 'Aurek-Besh', 'Tibetan'])\n"
     ]
    }
   ],
   "source": [
    "#Directory where the pickled data (train.pickle / val.pickle) is stored.\n",
    "#Set the ONESHOT_DATA_DIR environment variable to override the fallback below.\n",
    "PATH = os.environ.get(\"ONESHOT_DATA_DIR\", \"/home/soren/Desktop/keras-oneshot\")\n",
    "\n",
    "#NOTE: pickle.load can execute arbitrary code; only load pickle files you trust.\n",
    "#Each file unpickles to a pair: the image array and a dictionary keyed by alphabet name.\n",
    "with open(os.path.join(PATH, \"train.pickle\"), \"rb\") as f:\n",
    "    (X,c) = pickle.load(f)\n",
    "\n",
    "with open(os.path.join(PATH, \"val.pickle\"), \"rb\") as f:\n",
    "    (Xval,cval) = pickle.load(f)\n",
    "\n",
    "print(\"training alphabets\")\n",
    "print(c.keys())\n",
    "print(\"validation alphabets:\")\n",
    "print(cval.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading data from /home/soren/Desktop/keras-oneshot/train.pickle\n",
      "loading data from /home/soren/Desktop/keras-oneshot/val.pickle\n"
     ]
    }
   ],
   "source": [
    "class Siamese_Loader:\n",
    "    \"\"\"For loading batches and testing tasks to a siamese net.\n",
    "\n",
    "    Every data subset is read from \"<path>/<name>.pickle\", which must unpickle to a pair (X, c).\n",
    "    X is unpacked as a 4-d array (n_classes, n_examples, w, h) and indexed as X[class, example];\n",
    "    c maps a language name to a (low, high) pair of class indices.\n",
    "    \"\"\"\n",
    "    def __init__(self, path, data_subsets = (\"train\", \"val\")):\n",
    "        \"\"\"Load every subset named in data_subsets from its pickle file under path.\"\"\"\n",
    "        self.data = {}\n",
    "        self.categories = {}\n",
    "        self.info = {}\n",
    "\n",
    "        for name in data_subsets:\n",
    "            file_path = os.path.join(path, name + \".pickle\")\n",
    "            print(\"loading data from {}\".format(file_path))\n",
    "            #NOTE: pickle.load can execute arbitrary code; only load pickle files you trust\n",
    "            with open(file_path,\"rb\") as f:\n",
    "                (X,c) = pickle.load(f)\n",
    "                self.data[name] = X\n",
    "                self.categories[name] = c\n",
    "\n",
    "    def get_batch(self,batch_size,s=\"train\"):\n",
    "        \"\"\"Create batch of n pairs, first half different class, second half same class.\n",
    "\n",
    "        Returns (pairs, targets): pairs is a list of two arrays of shape (batch_size, w, h, 1)\n",
    "        and targets[i] is 1 when both images of pair i come from the same class, else 0.\n",
    "        The classes of the first images are sampled without replacement, so rng.choice\n",
    "        raises ValueError when batch_size exceeds the number of classes in subset s.\n",
    "        \"\"\"\n",
    "        X=self.data[s]\n",
    "        n_classes, n_examples, w, h = X.shape\n",
    "\n",
    "        #randomly sample several classes to use in the batch\n",
    "        categories = rng.choice(n_classes,size=(batch_size,),replace=False)\n",
    "        #initialize 2 empty arrays for the input image batch, in the same (w, h) order as the reshapes below\n",
    "        pairs=[np.zeros((batch_size, w, h,1)) for i in range(2)]\n",
    "        #initialize vector for the targets, and make one half of it '1's, so 2nd half of batch has same class\n",
    "        targets=np.zeros((batch_size,))\n",
    "        targets[batch_size//2:] = 1\n",
    "        for i in range(batch_size):\n",
    "            category = categories[i]\n",
    "            idx_1 = rng.randint(0, n_examples)\n",
    "            pairs[0][i,:,:,:] = X[category, idx_1].reshape(w, h, 1)\n",
    "            idx_2 = rng.randint(0, n_examples)\n",
    "            #pick images of different class for 1st half, same class for 2nd\n",
    "            if i >= batch_size // 2:\n",
    "                category_2 = category\n",
    "            else:\n",
    "                #add a random number to the category modulo n classes to ensure 2nd image has\n",
    "                # ..different category\n",
    "                category_2 = (category + rng.randint(1,n_classes)) % n_classes\n",
    "            pairs[1][i,:,:,:] = X[category_2,idx_2].reshape(w, h,1)\n",
    "        return pairs, targets\n",
    "\n",
    "    def generate(self, batch_size, s=\"train\"):\n",
    "        \"\"\"a generator for batches, so model.fit_generator can be used.\n",
    "\n",
    "        Yields (pairs, targets) tuples from get_batch(batch_size, s) forever.\n",
    "        \"\"\"\n",
    "        while True:\n",
    "            pairs, targets = self.get_batch(batch_size,s)\n",
    "            yield (pairs, targets)\n",
    "\n",
    "    def make_oneshot_task(self,N,s=\"val\",language=None):\n",
    "        \"\"\"Create pairs of test image, support set for testing N way one-shot learning.\n",
    "\n",
    "        Returns (pairs, targets): pairs is [test_image, support_set], both of shape (N, w, h, 1).\n",
    "        test_image repeats one image N times; support_set holds one image per sampled class.\n",
    "        targets is 1 at the position of the support image that shares the test image's class.\n",
    "        When language is given, all N classes are drawn from that language's index range and\n",
    "        ValueError is raised if the range holds fewer than N classes.\n",
    "        \"\"\"\n",
    "        X=self.data[s]\n",
    "        n_classes, n_examples, w, h = X.shape\n",
    "        indices = rng.randint(0,n_examples,size=(N,))\n",
    "        if language is not None:\n",
    "            #NOTE(review): high is treated as exclusive here; confirm against the code that wrote the pickle\n",
    "            low, high = self.categories[s][language]\n",
    "            if N > high - low:\n",
    "                raise ValueError(\"This language ({}) has less than {} letters\".format(language, N))\n",
    "            categories = rng.choice(range(low,high),size=(N,),replace=False)\n",
    "\n",
    "        else:#if no language specified just pick a bunch of random letters\n",
    "            categories = rng.choice(range(n_classes),size=(N,),replace=False)\n",
    "        true_category = categories[0]\n",
    "        #two distinct examples of the true class: one to classify, one for the support set\n",
    "        ex1, ex2 = rng.choice(n_examples,replace=False,size=(2,))\n",
    "        test_image = np.asarray([X[true_category,ex1,:,:]]*N).reshape(N, w, h,1)\n",
    "        #indexing with two index arrays returns a copy, so the assignment below leaves X untouched\n",
    "        support_set = X[categories,indices,:,:]\n",
    "        support_set[0,:,:] = X[true_category,ex2]\n",
    "        support_set = support_set.reshape(N, w, h,1)\n",
    "        targets = np.zeros((N,))\n",
    "        targets[0] = 1\n",
    "        #shuffle all three together so the correct answer is not always at position 0\n",
    "        targets, test_image, support_set = shuffle(targets, test_image, support_set)\n",
    "        pairs = [test_image,support_set]\n",
    "\n",
    "        return pairs, targets\n",
    "\n",
    "    def test_oneshot(self,model,N,k,s=\"val\",verbose=0):\n",
    "        \"\"\"Test average N way oneshot learning accuracy of a siamese neural net over k one-shot tasks.\n",
    "\n",
    "        A task counts as correct when the pair with the highest model.predict score is the\n",
    "        pair marked in targets. Returns the percentage of correct tasks.\n",
    "        \"\"\"\n",
    "        n_correct = 0\n",
    "        if verbose:\n",
    "            print(\"Evaluating model on {} random {} way one-shot learning tasks ...\".format(k,N))\n",
    "        for i in range(k):\n",
    "            inputs, targets = self.make_oneshot_task(N,s)\n",
    "            probs = model.predict(inputs)\n",
    "            if np.argmax(probs) == np.argmax(targets):\n",
    "                n_correct+=1\n",
    "        percent_correct = (100.0*n_correct / k)\n",
    "        if verbose:\n",
    "            print(\"Got an average of {}% {} way one-shot learning accuracy\".format(percent_correct,N))\n",
    "        return percent_correct\n",
    "\n",
    "    def train(self, model, epochs, verbosity, batch_size=32, steps_per_epoch=100, s=\"train\"):\n",
    "        \"\"\"Fit model on generated batches through model.fit_generator.\n",
    "\n",
    "        epochs and verbosity are forwarded as fit_generator's epochs and verbose arguments;\n",
    "        every epoch draws steps_per_epoch batches of batch_size pairs from subset s.\n",
    "        Returns whatever fit_generator returns.\n",
    "        \"\"\"\n",
    "        return model.fit_generator(self.generate(batch_size, s),\n",
    "                                   steps_per_epoch=steps_per_epoch,\n",
    "                                   epochs=epochs,\n",
    "                                   verbose=verbosity)\n",
    "    \n",
    "    \n",
    "#Instantiate the class: with the default subsets this reads train.pickle and val.pickle from PATH\n",
    "loader = Siamese_Loader(PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAHEAAADnCAYAAAApQbmOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAABvZJREFUeJztnUGW3CAMRHFe7n/lzoqEcbABIUFR\nXX+XzIwN/kjGoHZfn88nibP5tbsBYh5JJEASCZBEAiSRgN+Nn2vqisP19ANFIgGSSIAkEiCJBEgi\nAZJIgCQSIIkESCIBkkhAa9kthOv6uYKkjek5ICLxuq7/xIp+ICRmJNMGlMSMRI4BKTEliRzhakwq\nwmccLVma9PwFdz/x8/m8ilJEttkuMSORdmAkpiSRVqAkpiSRFuAkpiSRo0BKTKktUjL/ASsxJT1e\n9AItMSWJ7AFeotJmmy1bUbMoOn8CHYmKwj5gJT4JVBT+D6xE0Q+kREXhGJASa0jgM3ASNZkZB0qi\n0qgNKInCBoxERaEdGIk1JLAPaImiDwiJtVSqKOwHQqKYY7tEPRfOs11iDaXSMbZK1L3QB8hIFGNs\nk6go9AMmEiXQzhaJmpH6slyi0qg/MOlU2FkqUVEYgyKRgCOLh6M4NVMsi0TNSOPYmk5PGOUnsESi\nojCWcImn186cMAA1Oy04ZWDdCZV4ehRm0KMxTOKpAp/ahyzS5Tmxt4PoAk9F98QKp0WjJBKwTOJp\nqfSkaFQkvnDKwHOR2OrsKRfjVBSJBtBSarjE06PwhPaH7yeijVpGlE4JcJGoaNvLktmpiEXplIDQ\niQ1LhKLfLhSJBEgiAZLYAD2VpuQokeX+1wNaX10jMXeu9f1Pp3BCFKYUkE4Z5J2G7omDIA5SSRwA\nUWBK+lTUK6jS7igSCZBEAiSRAEkkQBIJkEQCJJEASSRAEgmQRAIkkQBJJEASCZBEAiSRAEkkQBIJ\naO3sn1Hu9eUoEgmQRAIkkQBJJEASCZBEAiSRAEkkQBIJkEQCJJEASSRAEgmQRAIkkQBJJEASCZBE\nAiSRgFaNzRmvj/gOHuudFIkvfO1rwa7rOqbzLNC92+2EAeTdRpg3SuWOeQyC67pev7Br90DL7ahh\naRuMxNz4mYtcXpjaRUKQl1K9Hbm9lv67S9yZzjwGwhOtfs2er2z7aPtDIhFlxM9QRkZK6/r0+XyG\nA8F1YuMVAVEz3N72lb/XeouypZ25f173RTeJniksX7hdqXmkH5Y+l3/TEtqD23dFRaSbnSKj8XxP\n+rTE6Cn7iMi3EY06GDxkTk1sVj1z9Z7jhK/FLQdT2Z6Ztpkkrp6xzWIdbFG3iJR+Zo3Z8wyn09ME\npoTZ1vvkZoar0cEfP0RZsmKjMzDmt6IksI7HhKlMsRa6JX6jQI9nuF6WT2y+hbedkJXnayGJBtCy\nksozCJBEAGbTsyQ+sGqZ7r7lZWH4njhyMrR7hxdPS2cW8trw0tlp78mmVyEGOmY5V886a2u/r+e8\nrX54PH9DzE5nOzL6t70lEB6ZpBwMUZkJQmIU94u3a923FrkeaTQTIhFliS5fKJRFe89F7xLYSPS8\n4JZjRVe35WPQ3BPveBZcWY8TGbX3QqzZ/uo50Yh3Ydj2QqkSlFqWyPvybJW6d7uG0umK1XuUSVEE\nb9dv2cP+yIk8VzVGiRoIHrPciHaFTWysM8LZTkZ+DmPmuJEZZvvs1CtivZ8FvTNJ6Gx3pFBKbEWf\n2WdGEgmQRAIkkQBJJGC7RK9VoN5ddpRlQU+2Pyd60buIvHtJb/va6S56O+5dy+L5rpnyb70XJoYX\nwHeP5FFmLlhUX73LRUwli+gi7xGE3N7lNTb3mpXW7+6iHOnIAr1wrzv1vnCoIqJ2SyzHhJ7Y7BC4\nsny/1jeId7t5MHrDr3W85zMOteOvGD
RPoqy3gTCJlpmXdbZ2f0b0eDdMFD1l/T2/VzK0YtNKNZYG\n3P92psQQUV45EXy7LjNFzsNvz3j95YkiXbSL70nvfbZRMPZ4kOGdfc/C3pS45ZU4XDe/nf3d5fWn\nEtnXbbsY3yQwmu1bUWIeSSRAEgmQRAJgJTKWUUQBIfEUYajthFwAnyFyQ3jFmzAsuL6MaHfH7qsi\nva868TpXSuML/lvenuG55Nb7fz3nrP397Meoa9ReZzJzDo/M4ZZOR0ZVbe+v5/+ezls7Znku667K\n0x7l079ni7Gsg+H4e2JkRVoW2RooXoPEyvZvqFlRgjGTVnP7ogXOXIfpSEQtZKox2taeorDR8z8d\nd9mLF+7M7ORnVg0Ar0nOzMfAPaO5xCzxaYqd0v5HjSesk5zaMbzwyGRTXzNU4i0uuoTeSpkSZx/8\nvQaD+WE/9G0QgPfZWr89PlOx9GF/1XITksC3W0Tks+0o3RJ3FtWuZvQij9wnI4IB6mEfQWBKcWkx\napBCbEV9C1GDVBIJeE2nqJugKKy8Pm9RrEgkQBIJkEQCJJEASSRAEgmQRAIkkQBJJEASCZBEAiSR\nAEkkQBIJkEQCJJEASSRAEgmQRAJaL+gTB6BIJEASCZBEAiSRAEkkQBIJ+ANUadHyA8XI6QAAAABJ\nRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f157e85e5f8>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "def concat_images(X):\n",
    "    \"\"\"Concatenates a bunch of images into a big matrix for plotting purposes.\n",
    "\n",
    "    X has shape (nc, h, w, 1). The images are laid out row by row on the smallest square\n",
    "    grid with at least nc cells; cells without an image stay zero.\n",
    "    Returns a 2-d array of shape (n*h, n*w) where n = ceil(sqrt(nc)).\n",
    "    \"\"\"\n",
    "    nc,h,w,_ = X.shape\n",
    "    X = X.reshape(nc,h,w)\n",
    "    #plain Python int: a numpy int8 grid size can overflow once multiplied by the image size\n",
    "    n = int(np.ceil(np.sqrt(nc)))\n",
    "    img = np.zeros((n*h,n*w))\n",
    "    for example in range(nc):\n",
    "        row, col = divmod(example, n)\n",
    "        #rows span the image height and columns the image width\n",
    "        img[row*h:(row+1)*h,col*w:(col+1)*w] = X[example]\n",
    "    return img\n",
    "\n",
    "\n",
    "def plot_oneshot_task(pairs):\n",
    "    \"\"\"Takes a one-shot task given to a siamese net and plots it.\n",
    "\n",
    "    pairs is the [test_image, support_set] list; both arrays are expected to have shape\n",
    "    (N, h, w, 1). The first test image is drawn in the top panel and the whole support set,\n",
    "    tiled by concat_images, in the bottom panel.\n",
    "    \"\"\"\n",
    "    fig,(ax1,ax2) = plt.subplots(2)\n",
    "    test_image = pairs[0][0]\n",
    "    #drop the channel axis using the image's own size instead of a fixed 105x105\n",
    "    ax1.matshow(test_image.reshape(test_image.shape[0],test_image.shape[1]),cmap='gray')\n",
    "    img = concat_images(pairs[1])\n",
    "    ax1.get_yaxis().set_visible(False)\n",
    "    ax1.get_xaxis().set_visible(False)\n",
    "    ax2.matshow(img,cmap='gray')\n",
    "    #hide the ticks of the support-set panel\n",
    "    ax2.set_xticks([])\n",
    "    ax2.set_yticks([])\n",
    "    plt.show()\n",
    "#example of a one-shot learning task: 20-way, drawn from the Japanese_(katakana) letters of the train subset\n",
    "pairs, targets = loader.make_oneshot_task(20,\"train\",\"Japanese_(katakana)\")\n",
    "plot_oneshot_task(pairs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "!\n",
      "training\n",
      "8.25082"
     ]
    }
   ],
   "source": [
    "\n",
    "#Training loop\n",
    "evaluate_every = 1 # interval for evaluating on one-shot tasks (iterations)\n",
    "#NOTE(review): with evaluate_every = 1 all n_val validation tasks run after every batch; confirm this is intended\n",
    "loss_every=50 # interval for printing loss (iterations)\n",
    "batch_size = 32\n",
    "n_iter = 90000\n",
    "N_way = 20 # how many classes for testing one-shot tasks\n",
    "n_val = 250 # how many one-shot tasks to validate on\n",
    "best = -1\n",
    "weights_path = os.path.join(PATH, \"weights\")\n",
    "print(\"training\")\n",
    "#iterations are numbered 1..n_iter so that exactly n_iter batches are trained on\n",
    "for i in range(1, n_iter + 1):\n",
    "    (inputs,targets)=loader.get_batch(batch_size)\n",
    "    loss=siamese_net.train_on_batch(inputs,targets)\n",
    "    if i % evaluate_every == 0:\n",
    "        print(\"evaluating\")\n",
    "        val_acc = loader.test_oneshot(siamese_net,N_way,n_val,verbose=True)\n",
    "        #keep the model whenever validation accuracy matches or beats the best seen so far\n",
    "        if val_acc >= best:\n",
    "            print(\"saving\")\n",
    "            siamese_net.save(weights_path)\n",
    "            best=val_acc\n",
    "\n",
    "    if i % loss_every == 0:\n",
    "        print(\"iteration {}, training loss: {:.2f},\".format(i,loss))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def nearest_neighbour_correct(pairs,targets):\n",
    "    \"\"\"returns 1 if nearest neighbour gets the correct answer for a one-shot task\n",
    "        given by (pairs, targets), else 0.\n",
    "\n",
    "    The prediction is the index i whose support image pairs[1][i] has the smallest\n",
    "    Euclidean distance to its paired test image pairs[0][i]; it is correct when that\n",
    "    index is the one marked in targets.\n",
    "    \"\"\"\n",
    "    n_pairs = len(targets)\n",
    "    test_images = np.asarray(pairs[0],dtype=float)[:n_pairs]\n",
    "    support_images = np.asarray(pairs[1],dtype=float)[:n_pairs]\n",
    "    #one flat row of pixel differences per pair\n",
    "    diffs = (test_images - support_images).reshape(n_pairs,-1)\n",
    "    #L2 distance: square root of the summed squared differences\n",
    "    L2_distances = np.sqrt(np.sum(diffs**2,axis=1))\n",
    "    if np.argmin(L2_distances) == np.argmax(targets):\n",
    "        return 1\n",
    "    return 0\n",
    "\n",
    "\n",
    "def test_nn_accuracy(N_ways,n_trials,loader):\n",
    "    \"\"\"Percentage of n_trials N_ways-way one-shot tasks, drawn from the \"val\" subset of\n",
    "    loader, that nearest neighbour answers correctly.\"\"\"\n",
    "    print(\"Evaluating nearest neighbour on {} unique {} way one-shot learning tasks ...\".format(n_trials,N_ways))\n",
    "\n",
    "    #each task scores 1 when nearest neighbour picks the right support image, else 0\n",
    "    hits = sum(\n",
    "        nearest_neighbour_correct(*loader.make_oneshot_task(N_ways,\"val\"))\n",
    "        for _ in range(n_trials)\n",
    "    )\n",
    "    return 100.0 * hits / n_trials\n",
    "\n",
    "\n",
    "#task sizes to evaluate: every odd N from 1 to 59\n",
    "ways = np.arange(1, 60, 2)\n",
    "resume = False\n",
    "trials = 450\n",
    "val_accs = []\n",
    "train_accs = []\n",
    "nn_accs = []\n",
    "for N in ways:\n",
    "    #siamese net on validation and training classes, then the nearest-neighbour baseline\n",
    "    val_accs.append(loader.test_oneshot(siamese_net, N, trials, \"val\", verbose=True))\n",
    "    train_accs.append(loader.test_oneshot(siamese_net, N, trials, \"train\", verbose=True))\n",
    "    nn_accs.append(test_nn_accuracy(N, trials, loader))\n",
    "\n",
    "#plot the accuracy vs num categories for each method; the red curve is the 100/N chance level\n",
    "for accs, colour in ((val_accs, \"m\"), (train_accs, \"y\"), (nn_accs, \"c\"), (100.0/ways, \"r\")):\n",
    "    plt.plot(ways, accs, colour)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#accuracy of each method against the number of classes in the task\n",
    "fig,ax = plt.subplots(1)\n",
    "ax.plot(ways,val_accs,\"m\",label=\"Siamese(val set)\")\n",
    "ax.plot(ways,train_accs,\"y\",label=\"Siamese(train set)\")\n",
    "ax.plot(ways,nn_accs,label=\"Nearest neighbour\")\n",
    "\n",
    "ax.plot(ways,100.0/ways,\"g\",label=\"Random guessing\")\n",
    "ax.set_xlabel(\"Number of possible classes in one-shot tasks\")\n",
    "ax.set_ylabel(\"% Accuracy\")\n",
    "ax.set_title(\"Omniglot One-Shot Learning Performance of a Siamese Network\")\n",
    "#shrink the axes to 80% width so the legend fits to the right of the plot\n",
    "box = ax.get_position()\n",
    "ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])\n",
    "ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))\n",
    "plt.show()\n",
    "\n",
    "#one 20-way validation task: show it and print the net's score for every pair\n",
    "inputs,targets = loader.make_oneshot_task(20,\"val\")\n",
    "print(inputs[0].shape)\n",
    "plot_oneshot_task(inputs)\n",
    "p=siamese_net.predict(inputs)\n",
    "print(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "\n",
    "#nearest-neighbour baseline: accuracy (%) over 500 random 3-way one-shot tasks\n",
    "a=test_nn_accuracy(3,500,loader)\n",
    "print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Raw Cell Format",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
